#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
from src.constant.file_and_path_constant import FileAndPathConstant

sys.path.append(FileAndPathConstant.System_Drive + '/github-repository/hades/dev-project/jormungandr')

import numpy as np
import math
import ctypes
import os
from numba import jit
# from numba import cuda, vectorize, guvectorize
# cuda.select_device(0)
from openpyxl import load_workbook
from sklearn.neighbors import KNeighborsClassifier as KNN
from sklearn.metrics import accuracy_score
from src.manager.log_manager import LogManager
from src.util.list_util import ListUtil
from src.constant.file_and_path_constant import FileAndPathConstant
from src.util.double_util import DoubleUtil
from src.manager.dll_manager import DllManager
from test.knn.point import Point as Point
from test.knn.point_struct import PointStruct

Logger = LogManager.get_logger(__name__)


class KNN2Test:
    """Hand-rolled k-nearest-neighbors classifier test over the Dry Bean dataset.

    Pipeline: read rows from an Excel file, min-max normalize the feature
    columns, split into training/testing sets, then classify each testing
    sample by majority vote among its k nearest training samples (Euclidean
    distance) and log the accuracy.
    """

    # Path to the Excel data file; may be overridden in __init__ based on the
    # USERDOMAIN environment variable.
    File_Path = "F:/github-repository/dataset/DryBeanDataset/DryBeanDataset/Dry_Bean_Dataset.xlsx"

    # Fractions used to split the data into training and testing sets.
    Training_Dataset_Percentage = 0.8
    Testing_Dataset_Percentage = 0.2

    # Mapping from bean class name (last column of the sheet) to its numeric
    # label string.
    Label_Dict = {'SEKER': '1', 'BARBUNYA': '2', 'BOMBAY': '3', 'CALI': '4', 'DERMASON': '5', 'HOROZ': '6', 'SIRA': '7'}

    # Number of nearest neighbors (k).
    Neighbors = 3

    # A single, already-normalized sample (16 features) kept for ad-hoc
    # single-sample prediction experiments.
    Single_Sample_Data = [['0.08407914738082632', '0.14776802402244507', '0.12050714190244108', '0.24134827654954272',
                           '0.14448574307519568', '0.521878107177938', '0.08228315132926864', '0.15863789290307861',
                           '0.6594086360936471', '0.8967091609975719', '0.8570377557686293', '0.754266803793188',
                           '0.45190032628038146', '0.6408568790070815', '0.7147889820523273', '0.9912080969458209']]

    # Expected label for the single sample above.
    Single_Sample_Label = ['1']

    def __init__(self):
        # Resolve the dataset path from the environment: a fixed F: path on the
        # author's machine, a system-drive-relative path everywhere else.
        user_do_main = os.environ.get("USERDOMAIN")
        if user_do_main == "LS2DL0YFY3IVX3W":
            KNN2Test.File_Path = "F:/github-repository/dataset/DryBeanDataset/DryBeanDataset/Dry_Bean_Dataset.xlsx"
        else:
            KNN2Test.File_Path = FileAndPathConstant.System_Drive + "/github-repository/dataset/DryBeanDataset/DryBeanDataset/Dry_Bean_Dataset.xlsx"
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    def read_file(self, file_path: str) -> list:
        """Read all data rows from the first worksheet of the Excel file.

        Rows are accumulated in a set, which drops exact-duplicate rows and
        yields them in arbitrary order — this serves as the original author's
        cheap "shuffle" (note: set order is arbitrary, not uniformly random).

        :param file_path: path of the .xlsx file to load
        :return: list of row tuples, header row excluded
        """
        Logger.info("开始从文件【" + file_path + "】中读取数据")

        workbook = load_workbook(file_path)
        worksheet = workbook.worksheets[0]

        all_data_set = set()
        for index, row_tuple in enumerate(worksheet.values):
            # The first row is the header; skip it.
            if index == 0:
                continue
            all_data_set.add(row_tuple)
        return ListUtil.set_to_list(all_data_set)

    def differentiate_training_dataset_and_testing_dataset(self, normal_all_data_list) -> tuple:
        """Split rows into training/testing feature lists and label lists.

        The last element of each row is the label; everything before it is a
        feature. The first Training_Dataset_Percentage of rows become the
        training set, the remainder the testing set.

        :param normal_all_data_list: normalized rows (features + label last)
        :return: (training features, training labels,
                  testing features, testing labels)
        """
        Logger.info("开始划分训练数据和测试数据")

        # Number of rows that go into the training set.
        row_number = len(normal_all_data_list)
        training_dataset_row_number = round(row_number * KNN2Test.Training_Dataset_Percentage)

        training_dataset_list = list()
        testing_dataset_list = list()
        training_dataset_label_list = list()
        testing_dataset_label_list = list()
        # FIX: the original re-used the name `index` for the inner column loop,
        # shadowing the outer row index, and copied columns one by one in two
        # duplicated branches; slicing expresses the same split directly.
        for row_index, row_list in enumerate(normal_all_data_list):
            features = list(row_list[:-1])
            label = row_list[-1]
            if row_index < training_dataset_row_number:
                training_dataset_list.append(features)
                training_dataset_label_list.append(label)
            else:
                testing_dataset_list.append(features)
                testing_dataset_label_list.append(label)
        return training_dataset_list, training_dataset_label_list, testing_dataset_list, testing_dataset_label_list

    def normalize(self, all_data_list) -> list:
        """Min-max normalize every feature column to [0, 1] as strings.

        The last column is the class name and is mapped through Label_Dict
        instead of being normalized.

        :param all_data_list: raw rows from the spreadsheet
        :return: rows with stringified normalized features + numeric label string
        """
        Logger.info("开始归一化")

        temp_all_data_list = list()

        # Transpose so each inner list is one column.
        reverse_all_data_list = list(map(list, zip(*all_data_list)))
        last_row_index = len(reverse_all_data_list) - 1
        for row_index, row_data_list in enumerate(reverse_all_data_list):
            # The last column holds the class name; map it to its label string.
            # (FIX: the original computed max/min before this check, pointlessly
            # comparing label strings.)
            if row_index == last_row_index:
                temp_all_data_list.append([KNN2Test.Label_Dict[value] for value in row_data_list])
                continue

            max_value = max(row_data_list)
            min_value = min(row_data_list)
            value_range = max_value - min_value
            # ROBUSTNESS FIX: a constant column previously raised
            # ZeroDivisionError; emit all zeros for it instead.
            if value_range == 0:
                temp_all_data_list.append(['0.0'] * len(row_data_list))
                continue
            temp_all_data_list.append([str((value - min_value) / value_range) for value in row_data_list])

        # Transpose back to row-major order.
        return list(map(list, zip(*temp_all_data_list)))

    def classify(self, normal_training_dataset, normal_training_dataset_label, normal_testing_dataset,
                 normal_testing_dataset_label):
        """Run the k-NN algorithm and log counts and accuracy.

        For every testing sample, compute its Euclidean distance to every
        training sample, take the Neighbors nearest ones, and predict the
        majority label. A tie between the two most frequent labels means the
        sample is counted as unpredictable.

        :param normal_training_dataset: 2-D list of float features
        :param normal_training_dataset_label: list of float labels
        :param normal_testing_dataset: 2-D list of float features
        :param normal_testing_dataset_label: list of float labels
        """
        Logger.info("开始执行k近邻算法")

        # all_distance_list[i] holds, sorted ascending by distance, Point
        # objects from testing sample i to every training sample.
        all_distance_list = list()

        for testing_dataset_row_index, testing_row in enumerate(normal_testing_dataset):
            # Distances from this one testing sample to all training samples.
            distance_list = list()
            for training_dataset_row_index, training_row in enumerate(normal_training_dataset):
                pow_sum = 0.0
                for column_index, column_value in enumerate(training_row):
                    pow_sum += math.pow(column_value - testing_row[column_index], 2)
                distance_list.append(Point(testing_dataset_row_index, training_dataset_row_index, math.sqrt(pow_sum)))
            # Ascending order — nearest neighbors first. (The original comment
            # claimed "descending", but reverse=False was already correct.)
            distance_list.sort(key=lambda point: point.distance)
            all_distance_list.append(distance_list)

        # Predict each testing sample's class from its k nearest neighbors.
        correct_count = 0
        can_not_predict_count = 0
        can_predict_count = 0
        for row_index, row_value in enumerate(all_distance_list):
            # Vote counts among the k nearest training samples' labels.
            count_dict = {}
            for i in range(KNN2Test.Neighbors):
                label = normal_training_dataset_label[row_value[i].training_dataset_index]
                count_dict[label] = count_dict.get(label, 0) + 1

            # Sort votes descending by count.
            count_tuple_order_by_desc = sorted(count_dict.items(), key=lambda item: item[1], reverse=True)

            # BUG FIX: the original branched on `len(...) == 0` and then indexed
            # [0] anyway (dead, self-contradicting branch), and its elif read
            # index [1] unconditionally — an IndexError whenever all k
            # neighbors agreed on a single label (the common case).
            if len(count_tuple_order_by_desc) > 1 and count_tuple_order_by_desc[0][1] == count_tuple_order_by_desc[1][1]:
                # Tie between the two most frequent labels: refuse to predict.
                Logger.info("无法确认类型，忽略")
                predict_label = None
            else:
                # BUG FIX: keep the label in its native (float) type — the
                # original wrapped it in str(), so the comparison with the
                # float testing label below could never be True and the
                # reported accuracy was always 0.
                predict_label = count_tuple_order_by_desc[0][0]
                Logger.info("类型为【" + str(predict_label) + "】")

            if predict_label is None:
                can_not_predict_count += 1
            else:
                # BUG FIX: the original incremented can_predict_count only when
                # the prediction was also correct, undercounting predictable
                # samples.
                can_predict_count += 1
                if normal_testing_dataset_label[row_index] == predict_label:
                    correct_count += 1

        Logger.info("无法预测的样本数量为【%s】，可以预测的样本数量为【%s】，正确率为【%s】",
                    can_not_predict_count, can_predict_count, str(correct_count / len(normal_testing_dataset_label)))

if __name__ == "__main__":
    knn2_test = KNN2Test()

    # Load the raw rows from the Excel dataset.
    raw_data_rows = knn2_test.read_file(KNN2Test.File_Path)

    # Min-max normalize the features and encode the class labels.
    normalized_rows = knn2_test.normalize(raw_data_rows)

    # Split into training/testing features and labels.
    (train_features, train_labels,
     test_features, test_labels) = knn2_test.differentiate_training_dataset_and_testing_dataset(normalized_rows)

    # The pipeline above produces strings; convert everything to floats.
    train_features = DoubleUtil.str_2dlist_to_float_2dlist(train_features)
    train_labels = DoubleUtil.str_list_to_float_list(train_labels)
    test_features = DoubleUtil.str_2dlist_to_float_2dlist(test_features)
    test_labels = DoubleUtil.str_list_to_float_list(test_labels)

    # Run the hand-rolled k-nearest-neighbors classifier.
    knn2_test.classify(train_features, train_labels, test_features, test_labels)

    Logger.info("完成")
